# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import torch
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData

from mmdet import *  # noqa
from mmdet.models.dense_heads import SOLOV2Head
from mmdet.structures.mask import BitmapMasks


def _rand_masks(num_items, bboxes, img_w, img_h):
    """Create ``BitmapMasks`` whose pixels are random only inside each box.

    A fixed-seed ``RandomState(0)`` is used, so the output is reproducible.

    Args:
        num_items (int): Number of masks allocated in the output.
        bboxes (np.ndarray): Boxes as ``(x1, y1, x2, y2)`` rows; coordinates
            are truncated to ``int32`` before slicing.
        img_w (int): Width of every mask.
        img_h (int): Height of every mask.

    Returns:
        BitmapMasks: ``num_items`` masks of shape ``(img_h, img_w)``. Pixels
        outside box ``i`` are 0; pixels inside are 1 where a uniform draw
        exceeds 0.3, else 0.
    """
    generator = np.random.RandomState(0)
    canvas = np.zeros((num_items, img_h, img_w))
    for idx, box in enumerate(bboxes):
        x1, y1, x2, y2 = box.astype(np.int32)[:4]
        # Drawing (h, w) consumes the same random sequence as (1, h, w);
        # the boolean result is stored as 0.0 / 1.0 in the float canvas.
        canvas[idx:idx + 1, y1:y2, x1:x2] = (
            generator.rand(y2 - y1, x2 - x1) > 0.3)
    return BitmapMasks(canvas, height=img_h, width=img_w)


def _fake_mask_feature_head():
    """Return a small mask-feature-head config used to build ``SOLOV2Head``.

    Returns:
        ConfigDict: 128 feature channels over levels 0-3, 256 output
        channels, a mask stride of 4 and a GroupNorm (32 groups) norm config.
    """
    gn_norm = dict(type='GN', num_groups=32, requires_grad=True)
    return ConfigDict(
        feat_channels=128,
        start_level=0,
        end_level=3,
        out_channels=256,
        mask_stride=4,
        norm_cfg=gn_norm)


class TestSOLOv2Head(TestCase):

    def test_solov2_head_loss(self):
        """Tests mask head loss when truth is empty and non-empty."""
        s = 256
        img_metas = [{
            'img_shape': (s, s, 3),
            'ori_shape': (s, s, 3),
            'scale_factor': 1,
            'batch_input_shape': (s, s, 3)
        }]

        mask_feature_head = _fake_mask_feature_head()

        mask_head = SOLOV2Head(
            num_classes=4, in_channels=1, mask_feature_head=mask_feature_head)

        # SOLO head expects multiple levels of features per image; level i
        # is the input size divided by 2**(i + 2).
        feats = tuple(
            torch.rand(1, 1, s // (2**(i + 2)), s // (2**(i + 2)))
            for i in range(len(mask_head.strides)))

        mask_outs = mask_head.forward(feats)

        # Test that empty ground truth encourages the network to
        # predict background
        gt_instances = InstanceData()
        gt_instances.bboxes = torch.empty(0, 4)
        gt_instances.labels = torch.LongTensor([])
        gt_instances.masks = _rand_masks(0, gt_instances.bboxes.numpy(), s, s)

        empty_gt_losses = mask_head.loss_by_feat(
            *mask_outs,
            batch_gt_instances=[gt_instances],
            batch_img_metas=img_metas)
        # When there is no truth, the cls loss should be nonzero but
        # there should be no mask loss.
        empty_cls_loss = empty_gt_losses['loss_cls']
        empty_mask_loss = empty_gt_losses['loss_mask']
        self.assertGreater(empty_cls_loss.item(), 0,
                           'cls loss should be non-zero')
        self.assertEqual(
            empty_mask_loss.item(), 0,
            'there should be no mask loss when there are no true mask')

        # When truth is non-empty then both cls and mask loss
        # should be nonzero for random inputs
        gt_instances = InstanceData()
        gt_instances.bboxes = torch.Tensor(
            [[23.6667, 23.8757, 238.6326, 151.8874]])
        gt_instances.labels = torch.LongTensor([2])
        gt_instances.masks = _rand_masks(1, gt_instances.bboxes.numpy(), s, s)

        one_gt_losses = mask_head.loss_by_feat(
            *mask_outs,
            batch_gt_instances=[gt_instances],
            batch_img_metas=img_metas)
        onegt_cls_loss = one_gt_losses['loss_cls']
        onegt_mask_loss = one_gt_losses['loss_mask']
        self.assertGreater(onegt_cls_loss.item(), 0,
                           'cls loss should be non-zero')
        self.assertGreater(onegt_mask_loss.item(), 0,
                           'mask loss should be non-zero')

    def test_solov2_head_empty_result(self):
        """Tests that ``_predict_by_feat_single`` returns an empty
        ``InstanceData`` when it is given zero kernel/class/mask rows."""
        s = 256
        # Keep the class-score width in sync with the head under test.
        num_classes = 4
        img_meta = {
            'img_shape': (s, s, 3),
            'ori_shape': (s, s, 3),
            'scale_factor': 1,
            'batch_input_shape': (s, s, 3)
        }

        mask_feature_head = _fake_mask_feature_head()
        mask_head = SOLOV2Head(
            num_classes=num_classes,
            in_channels=1,
            mask_feature_head=mask_feature_head)

        kernel_preds = torch.empty(0, 128)
        cls_scores = torch.empty(0, num_classes)
        mask_feats = torch.empty(0, 16, 16)
        test_cfg = ConfigDict(
            score_thr=0.1,
            mask_thr=0.5,
        )
        results = mask_head._predict_by_feat_single(
            kernel_preds=kernel_preds,
            cls_scores=cls_scores,
            mask_feats=mask_feats,
            img_meta=img_meta,
            cfg=test_cfg)

        self.assertIsInstance(results, InstanceData)
        self.assertEqual(len(results), 0)
